clear; close all; clc

% Before running, please read carefully:
% 1. Set the variable `folder` (search this script for "folder =") to the
%    full path of the state-level data folder that holds the
%    '<state>_preprocessed.csv' files
%    (e.g. 'C:\Users\...\Reproducible_Version_State_Level\data').
%
% 2. The folder Reproducible_Version contains two subfolders,
%    result_figures_reproducible and result_files_reproducible. They are
%    empty and will store the results when you run this code; run the
%    script with Reproducible_Version as the current folder, because the
%    outputs are written relative to pwd (figures go to
%    result_figures_reproducible\fig and result_figures_reproducible\png).
%    Results that were already generated are available in the folders
%    result_figures and result_files inside the folder Results.


% Load the cell array of state names (vec_state)
load vec_state.mat;

% Load the week labels and convert them to a character matrix
% (one label per row; later converted to datetime for date matching)
table_weeks = readtable('weeks_new_table.xlsx');
weeks       = char(table2array(table_weeks));
clear table_weeks

% Load the event table (Rt-based waves per state: column 2 = state name,
% column 4 = wave start date, column 5 = wave end date)
table_events = readtable('events_omicron_shifted_state_level.xlsx');

% Digital/predictor variable names (column names in the state data files)
vec_var1{1} = 'JHU_cases';

%% Choose variables and parameters
EW_tol_weeks = 6;                   % early-warning tolerance (weeks): an activation preceding an event by more than this would count as a false alarm. NOTE(review): not passed to fun_train/fun_test in this script.
alphamin     = 1;                   % minimum number of positive-derivative events (only used in output file names here)
strategy     = 'fast';              % tie-break among optimal parameter pairs: 'fast' keeps the pair with the smallest index sum (bottom-left), 'slow' the largest
window_gap   = 5;                   % sliding windows span window_gap+1 consecutive weeks
vec_bg       = 1;%[0.1:0.05:0.9];   % candidate thresholds, as a fraction of the predictor's maximum over the training period

vec_alpha  =  1;%alphamin:window_gap; % candidate alpha tolerances: a window with at least alpha weeks of increase is marked as an event

%% CODE STARTS HERE: LOOP ON DIGITAL/PREDICTOR VARIABLES STARTS HERE

% Full path to the folder holding the '<state>_preprocessed.csv' files.
% EDIT THIS before running.
folder = 'C:\Users\ch222818\Dropbox\Reproducible_Version_State_Level\data';

% Convert the week labels to datetime once; every wave start/end date below
% is located by exact comparison against this vector.
week_dates = datetime(weeks);

% Tie-break rule used when several parameter pairs are optimal:
% 'fast' keeps the pair with the smallest row+column index sum, 'slow' the
% largest. Fail early instead of hitting an undefined variable later.
if contains(strategy, 'fast')
    pick_extreme = @min;
elseif contains(strategy, 'slow')
    pick_extreme = @max;
else
    error('Unknown strategy ''%s''; expected ''fast'' or ''slow''.', strategy);
end

% Output folders (relative to the current folder); create them if missing
% so that saveas/save do not fail.
fig_dir = fullfile(pwd, 'result_figures_reproducible', 'fig');
png_dir = fullfile(pwd, 'result_figures_reproducible', 'png');
mat_dir = fullfile(pwd, 'result_files_reproducible');
for out_dir = {fig_dir, png_dir, mat_dir}
    if ~exist(out_dir{1}, 'dir')
        mkdir(out_dir{1});
    end
end
clear out_dir

for ind_var1 = 1:length(vec_var1)
    var1 = vec_var1{ind_var1};

    % The 'up2date' signal only has data up to week 61; all other signals
    % use every available week.
    if strcmp('up2date', var1)
        week_final = 61;
    else
        week_final = length(weeks);
    end

    %% START LOOP ON STATES
    for ind_state = 1:length(vec_state)
        fprintf('State %d of %d (%s), signal %s\n', ind_state, length(vec_state), vec_state{ind_state}, var1);

        % First week of the training set
        startweek_train = 1;

        % Events for this state. Exact (whitespace-trimmed) name match: a
        % substring match would also select e.g. 'West Virginia' rows when
        % processing 'Virginia'.
        aux_ind_state         = strcmp(strtrim(table_events{:,2}), vec_state{ind_state});
        table_selected_events = table_events(aux_ind_state,:);

        % With N waves, train on waves 1..X and test on wave X+1, X = 1..N-1
        for ind_last_train_wave = 1:size(table_selected_events,1)-1

            % Start weeks (column 4) of training waves 1..X; passed to fun_train
            vec_week_train_event = [];
            for ind_train_wave = 1:ind_last_train_wave
                vec_week_train_event = [vec_week_train_event find(week_dates == table_selected_events{ind_train_wave,4})]; %#ok<AGROW>
            end

            % Last week of the training set: end date (column 5) of wave X
            endweek_train = find(week_dates == table_selected_events{ind_last_train_wave,5});

            % Start (column 4) and end (column 5) weeks of test wave X+1
            vec_week_test_event = find(week_dates == table_selected_events{ind_last_train_wave+1,4});
            endweek_test        = find(week_dates == table_selected_events{ind_last_train_wave+1,5});

            %% BUILD TIME WINDOWS FOR TRAINING AND TEST DATA SETS
            vec_training_time = startweek_train:endweek_train;   % training weeks
            vec_test_time     = endweek_train+1:week_final;      % test weeks

            % Sliding windows of window_gap+1 consecutive training weeks
            n_train_windows  = max(length(vec_training_time) - window_gap, 0);
            windows_training = cell(1, n_train_windows);
            for j0 = 1:n_train_windows
                windows_training{j0} = vec_training_time(j0):vec_training_time(j0) + window_gap;
            end

            % Sliding windows of window_gap+1 consecutive test weeks
            n_test_windows = max(length(vec_test_time) - window_gap, 0);
            windows_test   = cell(1, n_test_windows);
            for jj0 = 1:n_test_windows
                windows_test{jj0} = vec_test_time(jj0):vec_test_time(jj0) + window_gap;
            end

            % Proceed only if both window sets are non-empty and the test
            % wave's start week was found exactly once within the analysed
            % period (an empty/duplicate match would otherwise break &&).
            if ~isempty(windows_training) && ~isempty(windows_test) && ...
                    isscalar(vec_week_test_event) && vec_week_test_event <= week_final

                %% Load state data
                fileName_state  = [vec_state{ind_state} '_preprocessed.csv'];
                fullFileName    = fullfile(folder, fileName_state);

                table_data      = readtable(fullFileName);                          % the state's preprocessed data
                var1_timeseries = table2array(table_data(1:week_final,var1));      % predictor series, weeks 1..week_final

                clear aux_ind_state fileName_state j0 jj0 n_train_windows n_test_windows

                %% PART 1: TRAINING STEP - FIXED EVENT AND CHOOSE OPTIMAL THRESHOLD
                [matrix_ACC_train,matrix_PPV_train,matrix_NPV_train,matrix_TP_train,matrix_FP_train,matrix_TN_train,matrix_FN_train,activation_weeks_var1_train,activation_weeks_var2_train] = fun_train(vec_training_time,windows_training,vec_bg,var1_timeseries,vec_week_train_event,vec_alpha);

                % Proceed only if fun_train returned results and at least one
                % PPV entry is not NaN (an all-NaN PPV matrix means no TP or FP
                % in any training run, i.e. 0/0 everywhere - discarded).
                if ~isempty(matrix_ACC_train) && any(~isnan(matrix_PPV_train(:)))

                    % Entries where PPV, NPV and ACC attain their maxima
                    ind_PPV = matrix_PPV_train == max(max(matrix_PPV_train));
                    ind_NPV = matrix_NPV_train == max(max(matrix_NPV_train));
                    ind_ACC = matrix_ACC_train == max(max(matrix_ACC_train));

                    % Candidate (row, column) pairs where all three maxima coincide
                    [ind_max1, ind_max2] = find(ind_PPV & ind_NPV & ind_ACC);
                    w_sig = ~isempty(ind_max1);   % 1: ACC, PPV and NPV jointly maximised; 0: fall back to PPV only
                    if ~w_sig
                        [ind_max1, ind_max2] = find(ind_PPV);
                    end

                    % Break ties by the strategy rule on the index sum
                    % (min/max return the first location of the extreme).
                    [~, ww]   = pick_extreme(ind_max1 + ind_max2);
                    alpha_opt = vec_alpha(ind_max1(ww(1)));   % row index -> alpha value
                    bg_opt    = vec_bg(ind_max2(ww(1)));      % column index -> threshold fraction
                    fprintf('bg_opt = %g, alpha_opt = %g\n', bg_opt, alpha_opt);

                    %% PART 2: TEST ON THE NEXT WAVE
                    [select_windows_th_act,select_windows_deriv_act,windows_TP_test,windows_FP_test,windows_TN_test,windows_FN_test,num_TP_test, num_FP_test, num_TN_test, num_FN_test,ACC_test,PPV_test,NPV_test,activation_weeks_var1_test,activation_weeks_var2_test] = fun_test(alpha_opt,bg_opt,vec_training_time,windows_test,var1_timeseries,vec_week_test_event);

                    clear ind_ACC ww ind_max1 ind_max2 ind_NPV ind_PPV ind_train_wave

                    %% PLOT RESULTS
                    h = figure('Visible', 'off');
                    % A .fig saved while invisible reopens invisible; this
                    % CreateFcn makes it visible again when reopened.
                    set(h, 'CreateFcn', 'set(gcbo,''Visible'',''on'')');

                    subplot(2,1,1)
                    plot(var1_timeseries,'b.-'); hold on
                    xline(endweek_train)
                    threshold_level = bg_opt*max(var1_timeseries(1:endweek_train));   % optimal threshold, drawn over the test period
                    line([endweek_train week_final],[threshold_level threshold_level],'Color','k')
                    ylabel(var1)

                    % Activation weeks returned by fun_test (one cell per test window)
                    aux_activation_weeks_var1_test    = [activation_weeks_var1_test{:}];
                    unique_activation_weeks_var1_test = unique(aux_activation_weeks_var1_test);
                    % unique() keeps NaNs (sorted last); drop them before plotting,
                    % otherwise a single NaN would suppress every marker.
                    valid_activation_weeks = unique_activation_weeks_var1_test(~isnan(unique_activation_weeks_var1_test));
                    if ~isempty(valid_activation_weeks)
                        plot_weeks = intersect(startweek_train:endweek_test, valid_activation_weeks);
                        plot(plot_weeks, var1_timeseries(plot_weeks), 'bo');
                        first_activation = valid_activation_weeks(1);
                    else
                        first_activation = NaN;
                    end
                    title(['alpha_{opt} =  ' num2str(alpha_opt) '. First out of sample activation: week ' num2str(first_activation) '.'])

                    subplot(2,1,2)
                    cases_timeseries = table2array(table_data(1:week_final,'JHU_cases'));
                    plot(cases_timeseries,'r.-'); hold on
                    title(['Rt activation: week ' num2str(vec_week_test_event) '.'])
                    xline(endweek_train)
                    ylabel('JHU cases')

                    if vec_week_test_event <= week_final
                        plot(vec_week_test_event,cases_timeseries(vec_week_test_event),'rs');
                    end

                    sgtitle([ vec_state{ind_state} ': training up to wave ' num2str(ind_last_train_wave) ' and testing on wave ' num2str(ind_last_train_wave+1) '.'],'Interpreter','none')
                    xlabel('weeks')

                    %% SAVE FILES AND FIGURES
                    % Common suffix shared by the figure and result file names
                    run_tag = ['omicron_shifted_parag_prob_state_level_training_upto_wave_' num2str(ind_last_train_wave) ...
                        '_test_wave_' num2str(ind_last_train_wave+1) '_alphamin_' num2str(alphamin) ...
                        '_strategy_' strategy '_signal_' var1 '_ind_state_' num2str(ind_state)];

                    figurename_fig = fullfile(fig_dir, ['Figure_' run_tag '.fig']);
                    figurename_png = fullfile(png_dir, ['Figure_' run_tag '.png']);
                    saveas(h, figurename_fig)
                    saveas(h, figurename_png)

                    close all % close figure after saving

                    % Save the whole workspace for this state/wave split
                    filename_mat = fullfile(mat_dir, ['Results_' run_tag '.mat']);
                    save(filename_mat)
                end % end of if on training results
            end     % end of if on available windows/test event

            % Reset per-wave variables; keep configuration and loop context only
            clearvars -except table_events var1 vec_alpha vec_bg vec_state vec_var1 week_final weeks window_gap strategy alphamin ind_state table_selected_events startweek_train EW_tol_weeks week_dates pick_extreme folder fig_dir png_dir mat_dir
        end % end loop on waves
    end % end loop on states
    clear var1 week_final

end % end loop on predictor